#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Author: w8ay
# @Date:   2017年12月19日 12:04:55
import imp
import os
import socket
import threading
import time
import types

from lib.core.data import paths
from thirdparty import miniCurl
from lib.utils import until
from lib.core.data import urlconfig,logger,w9config
from thirdparty import hackhttp
from thirdparty.ThreadPool import w8_threadpool
from lib.core.common import printMessage
from lib.core.data import w9_hash_pycode
from lib.core.settings import LIST_PLUGINS
from lib.core.exception import ToolkitUserQuitException
from lib.core.exception import ToolkitMissingPrivileges
from lib.core.exception import ToolkitSystemException
from lib.core.outhtml import buildHtml

class Exploit_run(object):
    """Load w9scan plugin sources and run them on a worker thread pool.

    Plugin files under ``paths.w9scan_Plugin_Path`` are read once at
    construction time and kept as source text in ``hash_pycode_Lists``
    (keyed by file basename). Whenever a service is announced through
    ``task_push``/``load_modules``, every plugin source is executed into a
    fresh module, wired up with the scanner helpers, and asked via
    ``assign(service, arg)`` whether it handles that service. Accepted
    plugins are pushed to the thread pool, whose workers call
    ``audit(agrs)`` through ``_work``.
    """

    def __init__(self, threadNum=15):
        """Read plugin sources and create the worker pool.

        :param threadNum: number of worker threads passed to ``w8_threadpool``.
        :raises ToolkitMissingPrivileges: if a plugin file cannot be read.
        """
        self.hash_pycode_Lists = {}          # plugin basename -> raw source
        self.lock_result = threading.Lock()  # guards timeout-retry bookkeeping in _work
        self.task_result = []
        self.lock_output = threading.Lock()  # serialises report/log output
        self.table_exception = set()         # basenames of plugins that have timed out

        # Plugins named in LIST_PLUGINS but absent from urlconfig.diyPlugin are skipped.
        remove_plugins = set(LIST_PLUGINS).difference(set(urlconfig.diyPlugin))

        dir_exploit = []
        for dirpath, _, filenames in os.walk(paths.w9scan_Plugin_Path):
            for filename in filenames:
                if self._plugin_name(filename) in remove_plugins:
                    continue
                path = os.path.join(dirpath, filename)
                if self._is_plugin_source(path):
                    dir_exploit.append(path)

        # Target description handed to every plugin as its ``_G`` attribute.
        self._TargetScanAnge = {'target': urlconfig.url,
                                'scanport': urlconfig.scanport,
                                'find_service': urlconfig.find_service
                                }

        try:
            for exp in dir_exploit:
                with open(exp, 'rb') as f:
                    # setdefault: the first file wins if two directories share a basename.
                    self.hash_pycode_Lists.setdefault(os.path.basename(exp), f.read())
        except Exception as error_info:
            raise ToolkitMissingPrivileges(error_info)

        self.buildHtml = buildHtml()
        self._print('Fetch %d new plugins' % len(self.hash_pycode_Lists))
        self.th = w8_threadpool(threadNum, self._work, urlconfig.mutiurl)
        logger.info('Set threadnum:%d' % threadNum)
        self.url = ""

    @staticmethod
    def _plugin_name(filename):
        """Return *filename* without a trailing ``.py`` suffix, if it has one.

        ``str.strip('.py')`` is NOT used here: it strips the character set
        {'.', 'p', 'y'} from both ends, turning e.g. ``phpinfo.py`` into
        ``hpinfo``.
        """
        if filename.endswith('.py'):
            return filename[:-3]
        return filename

    @staticmethod
    def _is_plugin_source(path):
        """Return True unless *path* names a package ``__init__`` or a ``.pyc`` file.

        Only the basename is inspected, so directory names cannot
        accidentally exclude every plugin beneath them.
        """
        basename = os.path.basename(path)
        return '__init__' not in basename and not basename.endswith('.pyc')

    def setCurrentUrl(self, url):
        """Set the URL recorded alongside subsequent report entries."""
        self.url = url

    def _prepare_plugin(self, source):
        """Execute plugin *source* and attach the helpers every plugin expects.

        The ``hackhttp`` attribute is left to the caller because spider and
        service plugins are configured differently.
        """
        pluginObj = self._load_module(source)
        pluginObj.task_push = self.task_push
        pluginObj.curl = miniCurl.Curl()
        pluginObj.security_note = self._security_note
        pluginObj.security_info = self._security_info
        pluginObj.security_warning = self._security_warning
        pluginObj.security_hole = self._security_hole
        pluginObj.security_set = self._security_set
        pluginObj.debug = self._debug
        pluginObj.util = until
        pluginObj._G = self._TargetScanAnge
        pluginObj.ThreadPool = w8_threadpool
        return pluginObj

    def init_spider(self):
        """Register plugins that accept the spider services in ``w9_hash_pycode``.

        Each plugin is asked ``assign("spider_file", "")`` and, if that does
        not return a tuple, ``assign("spider_end", "")``.

        :raises ToolkitMissingPrivileges: if loading or querying a plugin fails.
        """
        for k, v in self.hash_pycode_Lists.items():
            pluginObj = self._prepare_plugin(v)
            pluginObj.hackhttp = hackhttp.hackhttp()

            try:
                pluginObj_tuple = pluginObj.assign("spider_file", "")
                if not isinstance(pluginObj_tuple, tuple):  # not a tuple: try spider_end
                    pluginObj_tuple = pluginObj.assign("spider_end", "")
                    if not isinstance(pluginObj_tuple, tuple):
                        continue
                if pluginObj_tuple[0]:
                    # NOTE(review): "service" is recorded as "spider_file" even when
                    # the plugin matched "spider_end"; confirm consumers expect this.
                    pconf = {"pluginObj": pluginObj, "service": "spider_file"}
                    w9_hash_pycode.setdefault(k, pconf)
            except Exception as err_info:
                # str() is required: concatenating an exception object raises TypeError.
                raise ToolkitMissingPrivileges("load spider plugins error! " + str(err_info))

    def _load_module(self, chunk, name='<w9scan>'):
        """Execute plugin source *chunk* in a new module object and return it.

        NOTE: plugin code runs with full interpreter privileges; only load
        trusted plugin files.

        :raises ToolkitMissingPrivileges: if the source fails to compile or run.
        """
        try:
            pluginObj = types.ModuleType(str(name))
            # The exec(code, namespace) form is valid on both Python 2.7 and 3.
            exec(chunk, pluginObj.__dict__)
        except Exception as err_info:
            raise ToolkitMissingPrivileges("Load Module excepting: %s" % err_info)
        return pluginObj

    def load_modules(self, service, url):
        """Load every plugin, ask it about *service*, and queue those that accept.

        A plugin accepts when ``assign(service, url)`` returns a tuple whose
        first item is truthy; the second item is passed to ``audit`` later.
        Per-plugin failures are logged and do not stop the remaining plugins.
        """
        # Fill missing HTTP settings once instead of once per plugin.
        if w9config.TimeOut is None:
            w9config.TimeOut = 10
        if w9config.Cookie is None:
            w9config.Cookie = ""

        for k, v in self.hash_pycode_Lists.items():
            try:
                pluginObj = self._prepare_plugin(v)
                conpool = hackhttp.httpconpool(20, timeout=w9config.TimeOut)
                pluginObj.hackhttp = hackhttp.hackhttp(conpool=conpool,
                                                       cookie_str=w9config.Cookie,
                                                       user_agent=w9config.UserAgent,
                                                       headers=w9config.headers)

                pluginObj_tuple = pluginObj.assign(service, url)
                if not isinstance(pluginObj_tuple, tuple):  # plugin declined
                    continue
                bool_value, agrs = pluginObj_tuple[0], pluginObj_tuple[1]
                if bool_value:
                    threadConf = {
                        "filename": k,
                        "service": service,
                        "agrs": agrs,
                        "pluginObj": pluginObj,
                    }
                    self._print("load plugin %s for service '%s'" % (k, service))
                    self.th.push(threadConf)
            except Exception as err_info:
                logger.error("load plugin error:%s service:%s filename:%s" % (err_info, service, k))

    def run(self):
        """Start the worker pool."""
        self.th.run()

    def _work(self, threadConf):
        """Worker entry point: run one queued plugin's ``audit``.

        A task whose plugin raises ``socket.timeout`` is re-queued once; a
        second timeout of the same task is logged and dropped. Any other
        exception is logged so the worker keeps running.
        """
        self._print("running plugin %s for service '%s'" % (threadConf["filename"], threadConf["service"]))
        try:
            pluginObj = threadConf["pluginObj"]
            pluginObj.audit(threadConf["agrs"])
        except socket.timeout:
            with self.lock_result:
                if not threadConf.get("retried"):
                    # Mark the task so a persistently slow plugin is not re-queued forever.
                    threadConf["retried"] = True
                    self.table_exception.add(threadConf["filename"])
                    self.th.push(threadConf)
                    logger.warning("The plugin [name:%s service:%s] runs out of time and is retrying the queue" % (threadConf["filename"], threadConf["service"]))
                else:
                    logger.warning("The plugin [name:%s service:%s] retries failed." % (threadConf["filename"], threadConf["service"]))

        except Exception as error_info:
            logger.error("Running plugin error:%s service:%s filename:%s" % (error_info, threadConf["service"], threadConf["filename"]))

    def _security_note(self, body, k=''):
        """Record a "note"-level finding in the report and the log."""
        with self.lock_output:
            self.buildHtml.add_list("note", body, k, self.url)
            logger.security_note(body, k)

    def _security_info(self, body, k=''):
        """Record an "info"-level finding in the report and the log."""
        with self.lock_output:
            self.buildHtml.add_list("info", body, k, self.url)
            logger.security_info(body, k)

    def _security_warning(self, body, k=''):
        """Record a "warning"-level finding in the report and the log."""
        with self.lock_output:
            self.buildHtml.add_list("warning", body, k, self.url)
            logger.security_warning(body, k)

    def _security_hole(self, body, k=''):
        """Record a "hole"-level finding in the report and the log."""
        with self.lock_output:
            self.buildHtml.add_list("hole", body, k, self.url)
            logger.security_hole(body, k)

    def _security_set(self, level, body, k=''):
        """Add a finding with an explicit *level* via ``buildHtml.add_set``.

        Serialised with the same lock as the other report writers, since
        plugin threads call it concurrently.
        """
        with self.lock_output:
            self.buildHtml.add_set(level, body, k, self.url)

    def _debug(self, fmt, *args):
        """Debug-log ``fmt % args``.

        NOTE(review): calls with fewer than three format arguments are
        silently dropped; confirm this threshold is intended.
        """
        if len(args) >= 3:
            self._print(fmt % args)

    def task_push(self, serviceType, target_info, uuid=None, target=None, pr=-1):
        """Plugin callback: load plugins for *serviceType* against *target_info*.

        ``uuid``, ``target`` and ``pr`` are accepted but unused here.
        """
        self.load_modules(serviceType, target_info)

    def _print(self, *args):
        """Debug-log the str() of every argument, concatenated, under the output lock."""
        # fix TypeError bug: non-string arguments are converted with str() first
        with self.lock_output:
            logger.debug(u''.join([str(i) for i in args]))

    def report(self):
        """Build the HTML report (multi-target mode when ``urlconfig.mutiurl``)."""
        logger.info("Prepare for building result...")
        if urlconfig.mutiurl:
            self.buildHtml.mutiBuild()
        else:
            self.buildHtml.build()